{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Adaptivity of Posterior Variance with Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import argparse\n",
    "import os, sys\n",
    "import time\n",
    "import tabulate\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "import torchvision\n",
    "import numpy as np\n",
    "import tqdm\n",
    "\n",
    "from matplotlib import pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "from swag import data, models, utils, losses\n",
    "from swag.posteriors import SWAG, SGLD\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Files already downloaded and verified\n",
      "You are going to run models on the test set. Are you sure?\n",
      "Files already downloaded and verified\n"
     ]
    }
   ],
   "source": [
    "model_cfg = models.PreResNet20NoAug\n",
    "\n",
    "# Keyword options are gathered first so the positional part of the call stays short.\n",
    "loader_options = dict(use_validation=False, split_classes=None, shuffle_train=False)\n",
    "\n",
    "# NOTE(review): the positional 10000 and 4 are presumably the batch size and the\n",
    "# number of loader workers -- confirm against swag.data.loaders.\n",
    "loaders, num_classes = data.loaders(\n",
    "    \"CIFAR10\",\n",
    "    \"~/datasets/\",\n",
    "    10000,\n",
    "    4,\n",
    "    model_cfg.transform_train,\n",
    "    model_cfg.transform_test,\n",
    "    **loader_options\n",
    ")\n",
    "\n",
    "# The test split is iterated directly later on; the train split is reached via `loaders`.\n",
    "loader = loaders[\"test\"]\n",
    "\n",
    "# Build the network described by the config, then move it to the GPU.\n",
    "model = model_cfg.base(*model_cfg.args, num_classes=num_classes, **model_cfg.kwargs)\n",
    "model.cuda();"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "@torch.no_grad()\n",
    "def resample(model, sigma=1):\n",
    "    \"\"\"Overwrite every parameter of `model` in place with draws from N(0, sigma**2).\n",
    "\n",
    "    The function runs under `torch.no_grad()` so the in-place `copy_` is legal on\n",
    "    parameters that require gradients; callers may still wrap the call in their\n",
    "    own `no_grad` block.\n",
    "\n",
    "    Args:\n",
    "        model: object whose `parameters()` yields torch tensors.\n",
    "        sigma: standard deviation of the Gaussian noise.\n",
    "    \"\"\"\n",
    "    for p in model.parameters():\n",
    "        p.copy_(torch.randn_like(p) * sigma)\n",
    "\n",
    "\n",
    "def softmax(arr, axis=-1):\n",
    "    \"\"\"Softmax of a NumPy array along `axis`.\n",
    "\n",
    "    The maximum along `axis` is subtracted before exponentiating so that large\n",
    "    logits do not overflow; the returned array sums to 1 along `axis`.\n",
    "    \"\"\"\n",
    "    shifted = arr - np.max(arr, axis=axis, keepdims=True)\n",
    "    exp_shifted = np.exp(shifted)  # evaluated once, reused for the normaliser\n",
    "    return exp_shifted / np.sum(exp_shifted, axis=axis, keepdims=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "perclass = 100\n",
    "train_set = loaders['train'].dataset\n",
    "y = np.array(train_set.targets)\n",
    "\n",
    "# Indices of the first `perclass` examples of each class, ordered class by class.\n",
    "idx = np.concatenate([np.where(y == cls)[0][:perclass] for cls in range(10)])\n",
    "\n",
    "# Shrink the training set in place so later cells only see this subset.\n",
    "train_set.data = train_set.data[idx]\n",
    "train_set.targets = y[idx]\n",
    "\n",
    "assert len(train_set) == 10 * perclass"
   ]
  },
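  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test Data\n",
    "We will visualize posterior samples after observing `perclass * 10` training datapoints (1,000 images, given `perclass = 100` in the Train Data cell above). The cell below reassigns `perclass` to 20 and builds a fixed evaluation batch with up to 20 test images per class."
   ]
  },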
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Take the first batch produced by the test loader.\n",
    "x, y = next(iter(loader))\n",
    "\n",
    "classes = np.arange(10)\n",
    "perclass = 20\n",
    "\n",
    "# Boolean mask marking the first `perclass` occurrences of every class in the batch.\n",
    "mask = np.zeros_like(y).astype(bool)\n",
    "for cls in classes:\n",
    "    mask[np.where(y == cls)[0][:perclass]] = True\n",
    "\n",
    "x, y = x[mask].cuda(), y[mask]\n",
    "\n",
    "# Reorder the selected examples so that they are grouped by label.\n",
    "idx = np.argsort(y)\n",
    "x_test, y_test = x[idx], y[idx]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prior Samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "sigma = np.sqrt(10)\n",
    "samples = 500\n",
    "\n",
    "# Each draw re-initialises all weights from N(0, sigma**2) and records the\n",
    "# network outputs on the fixed test batch.\n",
    "prior_draws = []\n",
    "with torch.no_grad():\n",
    "    for _draw in range(samples):\n",
    "        resample(model, sigma=sigma)\n",
    "        logits = model(x_test)\n",
    "        prior_draws.append(logits.data.cpu().numpy()[:, :, None])\n",
    "\n",
    "# Stack along a trailing axis: one slice per draw.\n",
    "all_outputs = np.dstack(prior_draws)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<BarContainer object of 10 artists>"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAA3AAAAE/CAYAAAAHeyFHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAgAElEQVR4nO3df7icZ13n8feHpClYoGAb9oKmIZUGNQgreEi5FoQuv0ypNioFE9a11UpklyoKqKm4pRR0C3qB7lKVIJUKS0Oti3u0wcDyQy5ZWhOgFJNSOIRAE8CGthQrlBL47h/zpJ1Mz49JMmfmPGfer+uaq8+Pe2a+Z/Kd0/N97vu571QVkiRJkqSF7wGjDkCSJEmS1B8LOEmSJElqCQs4SZIkSWoJCzhJkiRJagkLOEmSJElqCQs4SZIkSWoJC7gWSfL2JK/rs+2qJJVkabP/3iTnzW+E0pE5kpyWFgJzVm1jzqqNzNvZWcCNiao6q6qunKtdU/SdPoyYpGPRXKT4UJJvJvlMkmePOiZpNklem+TTSQ4muWTU8UizSfKIJFcl+XKSO5N8NMkZo45Lmkvzt8GBJN9I8qkk60cd06BZwI3IoZ4xSUftKuCTwEnAq4BrkiwfbUjSrKaA3wKuHXUgUh8eDOwAfgz4fuBK4NokDx5pVNLcXgY8sqoeCmwC3pnkkSOOaaAs4AYoyd4kFyXZneSOJH+R5IHNuTOT7Evy20m+CvxFc/wnk9yQ5OtJ/l+SJ3S93hOTfCLJvyZ5N/DAWd57SZI/TPK1JHuAs3vOfzjJLzfbpyf5h+aK2tea1ybJR5rmn0pyV5KfG+Tno4WlycX9TX7dnORZzfG1ST7W5ORXkrw5ybKu51WS/5rkc81zX5vkMU3+fiPJ1Yfad+X97zS5tjfJf5olphm/Dz3tHgs8CXh1VX2rqv4a+DTw/EF+RlpY2pyzAFV1ZVW9F/jXAX4sWsDanLNVtaeq3lhVX6mq71bVFmAZ8IOD/ZS00LQ5bwGq6saqOnhoFzgOOHUgH85CUVU+BvQA9gL/TCdJvh/4KPC65tyZwEHg9cDxwIOAJwK3AmcAS4Dzmtc4ns4vyS8Cv0En8c4FvnPo9aZ575cAn+l67w/RSdqlzfkPA7/cbF9Fp8fiAXSKwqd1vU4Bp4/6s/Qx77n6g8AtwKOa/VXAY5rtHwOeAixtjt8E/HpPjvwf4KHA44BvAx8AfgA4EdgNnNe0PZT3b2zy+hnAvwE/2Jx/e9d3ZMbvwzTx/wxwU8+xNwP/c9SfrQ9zdrqc7flZ3glcMurP1Ic522/ONs/9UeBu4MRRf7Y+zNu58hb4uyZfC/h74AGj/mwH+bAHbvDeXFW3VNXtwO8BG7vOfY9Oj8G3q+pbdLp131JV11fn6taVdJL9Kc3jOOCPquo7VXUNnaEMM3lh0/bQe//3Wdp+B3g0nS/n3VX1j0f7w6q1vkvnF+aaJMdV1d6q+jxAVX28qq6rqoNVtRd4C51frN3eUFXfqKpddC5avK86V2vvBN5L55dtt//W5P0/0Bk+9sJpYprt+9DrwcCdPcfuBB7S58+v9ml7zmr8LJqcTfJQ4B3Aa5r31+K1KPK2qn6Szt8Ez2ti+N6RfQwLmwXc4N3Stf1F4FFd+weq6u6u/UcDr2i6g7+e5Ot0etAe1Tz2V3UuI3S93kweNc17z+S3gAD/lGRXkl+apa0WoaqaAn4duAS4NcnWJI+CzvDEJH+X5KtJvgH8PnByz0v8S9f2t6bZ775H4o6q+reu/d7vxSGzfR963UXnCl+3h+LQtEVrEeSsxsxiydkkDwL+Friuqma7OKxFYLHkbfOzfKc6w9afm+Sc2dq2jQXc4HWPsV0JfLlrv3ra3gL8XlU9rOvxfVV1FfAV4JQk6Xm9mXxlmveeVlV9tapeXFWPAn4F+JM48+TYqap3VdXT6PxiLDrDewH+lM5w3NXVuQH4d+gU
/Efr4UlO6Nrv/V4cMtv3odcu4AeSdPe4/fvmuBapluesxlDbczbJ8cDfAPvo/L2gMdD2vJ3GUuAxxxDngmMBN3gvTbIiyffTuc/s3bO0fSvwkiRnpOOEJGc3f5R+jM7Y4F9LclySnwXWzvJaVzdtVyR5OLB5poZJXpBkRbN7B50v56Gu5X+hM1ZZi1iSH0zyzOZ/znfTuSp2KAceAnwDuCvJDwH/ZQBv+Zoky5L8OPCTwF9N02a278NhquqzwA3Aq5M8MMnPAE8A/noAsWoBanvONj/DcelMbPUAYGmTu0sGEKsWoLbnbJLjgGuauM9bbEPQNL1FkLc/lOSsJA9qfuf+PPB04B8GEOuCYQE3eO8C3gfsAT4PzLgIYVXtBF5MZ/KFO+hMMX1+c+4e4Geb/duBnwP+9yzv+1ZgO/Ap4BNztH0ycH2Su4BJ4GVVtac5dwlwZdNFPd04ZC0OxwOXAV8Dvgo8ArioOfdK4EV0hiO+ldkvQvTjq3Ty+8vA/wJeUlWf6W002/dhBhuAiabtZcC5VXXgGGPVwrUYcvatdP4Y2kjnAt+3gP98jLFq4Wp7zv4HOn9QPxf4ejqzU9/V/KGtxavteRua4Z/AATpLCvxcVX3iGGNdUHL4LVY6Fkn20pnp8f+OOhZpIUhyJvDOqloxV1tpITBn1TbmrNrIvD029sBJkiRJUktYwEmSJElSSziEUpIkDUySK+jcO3VrVf3INOcD/DGd9Zm+CZy/2O5PkaT5ZA+cJEkapLcD62Y5fxawunlsojM1uSSpTxZwkiRpYKrqI3RmT57JeuAvq+M64GFJHjmc6CSp/ZaOOoBeJ598cq1atWrUYajFPv7xj3+tqpYP8z3NWx0Lc1Ztc4w5ewqdhXkP2dcc+8psTzJndSzmytkk6+gM7V0C/HlVXdZz/k3Af2x2vw94RFU9bLb3NGd1LGbL2QVXwK1atYqdO3eOOgy1WJIvDvs9zVsdC3NWbTOsnE2yic4wS1auXGnO6qjNlrNJlgCXA8+hc0FhR5LJqtp9qE1V/UZX+18FnjjXe/p7Vsditpx1CKUkSRqm/cCpXfsrmmP3U1VbqmqiqiaWLx9qJ7XGy1pgqqr2VNU9wFY6Q31nshG4aiiRSdOwgJMkScM0CfxCOp4C3FlVsw6flObZTMN67yfJo4HTgA/OcH5Tkp1Jdh44cGDggUqwAIdQSpKk9kpyFXAmcHKSfcCrgeMAqurPgG10lhCYorOMwC+OJlLpqGwArqmq7053sqq2AFsAJiYmXKtL88ICTpIkDUxVbZzjfAEvHVI4Uj/6HtZLp4AzfzVSDqGUJEnSONsBrE5yWpJldIq0yd5GSX4IeDjwsSHHJx3GAk6SJEljq6oOAhcC24GbgKuraleSS5Oc09V0A7C16UWWRsYhlJIkSRprVbWNzv2Z3ccu7tm/ZJgxSTOxB05jJckVSW5N8s8znE+S/5FkKsmNSZ407BilbuasJEnq1lcBl2RdkpubPxA2T3P+6Uk+keRgknO7jv9oko8l2dX8YfFzgwxeOgpvB9bNcv4sYHXz2AT86RBikmbzdsxZSZLUmLOA61qd/ixgDbAxyZqeZl8Czgfe1XP8m8AvVNXj6PwB8kdJHnasQUtHq6o+Atw+S5P1wF9Wx3XAw5I8cjjRSfdnzkqSpG799MDNuTp9Ve2tqhuB7/Uc/2xVfa7Z/jJwK7B8IJFL86PvxTylBcKclSRpjPQzicl0fxyccaRvlGQtsAz4/DTnNtEZ+sPKlSuP9KU1z1ZtvravdnsvO3ueI1lYxjlvzYl2Wmg5228egbkkzSe/i+3gv5MOGcokJs1wnncAv1hV3+s9X1VbqmqiqiaWL7eDTiPV92Ke5q0WCHNWkqQx0k8BdySr099PkocC1wKvau7PkBaySeAXmpn9ngLcWVVfGXVQ0izMWUmSxkg/QyjvXZ2eTuG2AXhRPy/erGb/Hjo32F9z1FFKA5LkKuBM4OQk+4BXA8cBVNWf0VkD
5nnAFJ1JeH5xNJFKHeasJEnqNmcBV1UHkxxanX4JcMWh1emBnVU1meTJdAq1hwM/leQ1zcyTLwSeDpyU5PzmJc+vqhvm44eR5lJVG+c4X8BLhxSONCdzVpIkdeunB27O1emragedoZW9z3sn8M5jjFGSJEmSxJAmMZEkSZIkHTsLOEmSJElqCQs4SZIkSWoJCzhJkiRJagkLOEmSJElqCQs4SZIkSWoJCzhJkiRJaom+1oFbKFZtvrbvtnsvO3seI5EkSZKk4bMHTpIkSZJawgJOkiRJklrCAk6SJEmSWsICTpIkSZJawgJOkiRJklrCAk6SJEmSWsICTpIkSZJawgJOkiRJklrCAk6SJEmSWsICTpIkSZJaYumoA5hvqzZf23fbvZedPY+RSJIkSdKxWfQFnNrDYluSJI1CknXAHwNLgD+vqsumafNC4BKggE9V1YuGGqTUsICTJEnS2EqyBLgceA6wD9iRZLKqdne1WQ1cBDy1qu5I8ojRRCt5D5wkSZLG21pgqqr2VNU9wFZgfU+bFwOXV9UdAFV165BjlO5lASdJkqRxdgpwS9f+vuZYt8cCj03y0STXNUMupZFwCKUkSZI0u6XAauBMYAXwkSSPr6qvdzdKsgnYBLBy5cphx6gxYQ+cJEmSxtl+4NSu/RXNsW77gMmq+k5VfQH4LJ2C7jBVtaWqJqpqYvny5fMWsMabBZwkSZLG2Q5gdZLTkiwDNgCTPW3+hk7vG0lOpjOkcs8wg5QOsYCTJEnS2Kqqg8CFwHbgJuDqqtqV5NIk5zTNtgO3JdkNfAj4zaq6bTQRa9x5D5wkSZLGWlVtA7b1HLu4a7uAlzcPaaTsgZMkSZKklrCAkyRJkqSWsICTJEmSpJawgJMkSZKklrCAkyRJA5VkXZKbk0wl2TzN+ZVJPpTkk0luTPK8UcQpSW1kASdJkgYmyRLgcuAsYA2wMcmanma/S2eq9ifSWXPrT4YbpSS1V18FXB9X0p6e5BNJDiY5t+fceUk+1zzOG1TgkiRpQVoLTFXVnqq6B9gKrO9pU8BDm+0TgS8PMT5JarU514HrupL2HGAfsCPJZFXt7mr2JeB84JU9z/1+4NXABJ1f1h9vnnvHYMKXJEkLzCnALV37+4AzetpcArwvya8CJwDPHk5oktR+/fTAzXklrar2VtWNwPd6nvsTwPur6vamaHs/sG4AcUuSpPbaCLy9qlYAzwPekeR+f5Mk2ZRkZ5KdBw4cGHqQkrQQ9VPATXcl7ZQ+X7+v5/oLWpKkRWM/cGrX/ormWLcLgKsBqupjwAOBk3tfqKq2VNVEVU0sX758nsKVpHZZEJOY+AtakqRFYwewOslpSZbRmaRksqfNl4BnAST5YToFnFdwJakP/RRw/VxJm4/nSpKklqmqg8CFwHbgJjqzTe5KcmmSc5pmrwBenORTwFXA+VVVo4lYktplzklM6LqSRqf42gC8qM/X3w78fpKHN/vPBS464iglSVJrVNU2YFvPsYu7tncDTx12XJK0GMzZA9fPlbQkT06yD3gB8JYku5rn3g68lk4RuAO4tDkmSZIkSTpCfd0DV1XbquqxVfWYqvq95tjFVTXZbO+oqhVVdUJVnVRVj+t67hVVdXrz+Iv5+TGk/vSxpuHKJB9K8skkNyZ53ijilLqZt5Ik6ZAFMYmJNAxdaxqeBawBNiZZ09Psd+n0Mj+RznDhPxlulNLhzFtJktTNAk7jZM41DeksOP/QZvtE4MtDjE+ajnkrSZLu1c8kJtJiMd26hGf0tLkEeF+SXwVOAJ49nNCkGZm3kiTpXvbASYfbCLy9qlYAzwPekWTa74kL0GsB6StvzVlJktrPHjjNi1Wbr+277d7Lzp7HSA7Tz7qEFwDrAKrqY0keCJwM3Nr7YlW1BdgCMDEx4fpFmi8Dy1tzVpKk9rMHTuPk3jUNkyyjM9nDZE+bLwHPAkjyw8ADAbsqNErmrSRJupcFnMZGP2saAq8AXpzkU8BVwPlVZU+F
Rsa8lSRJ3RxCuQD0O9xwiEMNF62q2gZs6zl2cdf2buCpw45Lmo15K0mSDrGAa6kFeo+ZJEmSpHlkATdAFlWSJEmS5pP3wEmSJElSS1jASZIkSVJLWMBJkiRJUktYwEmSJElSS1jASZIkaawlWZfk5iRTSTZPc/78JAeS3NA8fnkUcUrgLJSSJEkaY0mWAJcDzwH2ATuSTDZrbHZ7d1VdOPQApR72wEmSJGmcrQWmqmpPVd0DbAXWjzgmaUYWcJIkSRpnpwC3dO3va471en6SG5Nck+TU4YQm3Z8FnCRJkjS7vwVWVdUTgPcDV07XKMmmJDuT7Dxw4MBQA9T48B64GazafG1f7fZedvY8RyJJkqR5tB/o7lFb0Ry7V1Xd1rX758AbpnuhqtoCbAGYmJiowYYpddgDJ0mSpHG2A1id5LQky4ANwGR3gySP7No9B7hpiPFJh7EHTpIkSWOrqg4muRDYDiwBrqiqXUkuBXZW1STwa0nOAQ4CtwPnjyxgjT0LOEmSJI21qtoGbOs5dnHX9kXARcOOS5qOQyglSZIkqSUs4CRJkiSpJSzgJEmSJKklLOAkSZIkqSUs4CRJkiSpJSzgJEmSJKklLOAkSZIkqSUs4CRJkiSpJVzIe4ys2nxt3233Xnb2PEYiSZIk6WjYAydJkiRJLWEBJ0mSJEkt0VcBl2RdkpuTTCXZPM3545O8uzl/fZJVzfHjklyZ5NNJbkpy0WDDlyRJkqTxMWcBl2QJcDlwFrAG2JhkTU+zC4A7qup04E3A65vjLwCOr6rHAz8G/Mqh4k6SJEmSdGT66YFbC0xV1Z6qugfYCqzvabMeuLLZvgZ4VpIABZyQZCnwIOAe4BsDiVySJEmSxkw/BdwpwC1d+/uaY9O2qaqDwJ3ASXSKuX8DvgJ8CfjDqrr9GGOWJEmSpLE035OYrAW+CzwKOA14RZIf6G2UZFOSnUl2HjhwYJ5DkiRJkqR26qeA2w+c2rW/ojk2bZtmuOSJwG3Ai4C/r6rvVNWtwEeBid43qKotVTVRVRPLly8/8p9CkiRJksZAPwXcDmB1ktOSLAM2AJM9bSaB85rtc4EPVlXRGTb5TIAkJwBPAT4ziMAlSZIkadzMWcA197RdCGwHbgKurqpdSS5Nck7T7G3ASUmmgJcDh5YauBx4cJJddArBv6iqGwf9Q0iSJEnSOFjaT6Oq2gZs6zl2cdf23XSWDOh93l3THZckSZIkHbn5nsREkiRJkjQgFnCSJGmgkqxLcnOSqSSbZ2jzwiS7k+xK8q5hxyhJbdXXEEpJkqR+JFlC5x7459BZO3ZHksmq2t3VZjVwEfDUqrojySNGE60ktY89cJIkaZDWAlNVtaeq7gG2Aut72rwYuLyq7gBolhqSJPXBAk6SJA3SKcAtXfv7mmPdHgs8NslHk1yXZN10L5RkU5KdSXYeOHBgnsKVpHaxgJMkScO2FFgNnAlsBN6a5GG9japqS1VNVNXE8uXLhxyiJC1MFnCSJGmQ9gOndu2vaI512wdMVtV3quoLwGfpFHSSpDlYwEmSpEHaAaxOclqSZcAGYLKnzd/Q6X0jycl0hlTuGWaQktRWFnAaK05trTYyb9UmVXUQuBDYDtwEXF1Vu5JcmuScptl24LYku4EPAb9ZVbeNJmJJaheXEdDYcGprtZF5qzaqqm3Atp5jF3dtF/Dy5iFJOgL2wGmcOLW12si8laR51s9Ih6bd85NUkolhxid1s4DTOBnY1Nbg9NYaGqdkl6R51DXS4SxgDbAxyZpp2j0EeBlw/XAjlA5nAScdrq+prcHprbWgOCW7JB29fkY6ALwWeD1w9zCDk3pZwGmcOLW12si8laT5NedIhyRPAk6tqmtneyFHOmgYLOA0TpzaWm1k3krSCCV5APBG4BVztXWkg4bBAk5jw6mt1UbmrSTNu7lGOjwE+BHgw0n2Ak8BJp3IRKPiMgIaK05trTYybyVpXt070oFO4bYBeNGhk1V1J3Dyof0kHwZe
WVU7hxynBNgDJ0mSpDHW50gHacGwB06SJEljba6RDj3HzxxGTNJM7IGTJEmSpJawgJMkSZKklrCAkyRJkqSWsICTJEmSpJawgJMkSZKklrCAkyRJkqSWsICTJEmSpJawgJMkSZKklrCAkyRJkqSWsICTJEmSpJawgJMkSZKklrCAkyRJkqSWsICTJEmSpJawgJMkSZKklrCAkyRJkqSW6KuAS7Iuyc1JppJsnub88Une3Zy/PsmqrnNPSPKxJLuSfDrJAwcXviRJkiSNjzkLuCRLgMuBs4A1wMYka3qaXQDcUVWnA28CXt88dynwTuAlVfU44EzgOwOLXpIkSZLGSD89cGuBqaraU1X3AFuB9T1t1gNXNtvXAM9KEuC5wI1V9SmAqrqtqr47mNAlSZIkabz0U8CdAtzStb+vOTZtm6o6CNwJnAQ8Fqgk25N8IslvHXvIkiRJkjSelg7h9Z8GPBn4JvCBJB+vqg90N0qyCdgEsHLlynkOSZIkSZLaqZ8euP3AqV37K5pj07Zp7ns7EbiNTm/dR6rqa1X1TWAb8KTeN6iqLVU1UVUTy5cvP/KfQpIkSZLGQD8F3A5gdZLTkiwDNgCTPW0mgfOa7XOBD1ZVAduBxyf5vqawewawezChS5IkSdJ4mXMIZVUdTHIhnWJsCXBFVe1Kcimws6omgbcB70gyBdxOp8ijqu5I8kY6RWAB26rq2nn6WSRJkiRpUevrHriq2kZn+GP3sYu7tu8GXjDDc99JZykBSZIkSdIx6Gshb0mSJEnS6FnASZIkSVJLWMBJkiRJUktYwEmSJElSS1jASZIkSVJLWMBJkiRprCVZl+TmJFNJNk9z/iVJPp3khiT/mGTNKOKUwAJOkiRJYyzJEuBy4CxgDbBxmgLtXVX1+Kr6UeANwBuHHKZ0Lws4SZIkjbO1wFRV7amqe4CtwPruBlX1ja7dE4AaYnzSYfpayFuSJElapE4Bbuna3wec0dsoyUuBlwPLgGcOJzTp/uyBkyRJkuZQVZdX1WOA3wZ+d7o2STYl2Zlk54EDB4YboMaGBZwkSZLG2X7g1K79Fc2xmWwFfnq6E1W1paomqmpi+fLlAwxRuo8FnCRJksbZDmB1ktOSLAM2AJPdDZKs7to9G/jcEOOTDmMBJ0mSBmquKdm72j0/SSWZGGZ8UreqOghcCGwHbgKurqpdSS5Nck7T7MIku5LcQOc+uPNGFK7kJCaSJGlwuqZkfw6dySB2JJmsqt097R4CvAy4fvhRSoerqm3Atp5jF3dtv2zoQUkzsAdOkiQN0pxTsjdeC7weuHuYwUlS21nASZKkQZpuSvZTuhskeRJwalVdO8zAJGkxsIDTWPG+DLWReavFJMkDgDcCr+ijrVOyS1IPCziNja77Ms4C1gAbk6yZpp33ZWjBMG/VQnNNyf4Q4EeADyfZCzwFmJzuwoNTskvS/VnAaZx4X4bayLxV28w6JXtV3VlVJ1fVqqpaBVwHnFNVO0cTriS1iwWcxon3ZaiNzFu1Sp9TskuSjpLLCEiNrvsyzu+z/SZgE8DKlSvnLzBpFkeSt+ashmWuKdl7jp85jJgkabGwB07jZGD3ZYD3ZmhovJ9IkiTdywJO48T7MtRG5q0kSbqXBZzGhvdlqI3MW0mS1M174DRWvC9DbWTeSpKkQ+yBkyRJkqSWsICTJEmSpJawgJMkSZKklrCAkyRJkqSWsICTJEmSpJawgJMkSZKklrCAkyRJkqSWsICTJEmSpJawgJMkSZKkluirgEuyLsnNSaaSbJ7m/PFJ3t2cvz7Jqp7zK5PcleSVgwlbkiRJksbPnAVckiXA5cBZwBpgY5I1Pc0uAO6oqtOBNwGv7zn/RuC9xx6uJEmSJI2vfnrg1gJTVbWnqu4BtgLre9qsB65stq8BnpUkAEl+GvgCsGswIUuSJEnSeOqngDsFuKVrf19zbNo2VXUQuBM4KcmDgd8GXnPsoUqSJEnSeJvvSUwuAd5UVXfN1ijJpiQ7
k+w8cODAPIckSZIkSe20tI82+4FTu/ZXNMema7MvyVLgROA24Azg3CRvAB4GfC/J3VX15u4nV9UWYAvAxMREHc0PIkmSJEmLXT8F3A5gdZLT6BRqG4AX9bSZBM4DPgacC3ywqgr48UMNklwC3NVbvEmSJEmS+jNnAVdVB5NcCGwHlgBXVNWuJJcCO6tqEngb8I4kU8DtdIo8SZIkSdIA9dMDR1VtA7b1HLu4a/tu4AVzvMYlRxGfJEmSJKkx35OYSJIkSQtaknVJbk4ylWTzNOdfnmR3khuTfCDJo0cRpwQWcJIkSRpjSZYAlwNnAWuAjUnW9DT7JDBRVU+gs+bxG4YbpXQfCzhJkiSNs7XAVFXtqap7gK3A+u4GVfWhqvpms3sdnVnZpZGwgJMkSdI4OwW4pWt/X3NsJhcA753XiKRZ9DWJiSRJkjTukvw8MAE8Y4bzm4BNACtXrhxiZBon9sBJkiRpnO0HTu3aX9EcO0ySZwOvAs6pqm9P90JVtaWqJqpqYvny5fMSrGQBJ0mSpHG2A1id5LQky+isZzzZ3SDJE4G30Cnebh1BjNK9LOAkSZI0tqrqIHAhsB24Cbi6qnYluTTJOU2zPwAeDPxVkhuSTM7wctK88x44SZIkjbWq2gZs6zl2cdf2s4celDQDe+AkSZIkqSUs4CRJkiSpJSzgJEmSJKklLOAkSZIkqSUs4CRJkiSpJZyFUpIkSdKitmrztX233XvZ2fMYybGzB06SJEmSWsICTpIkSZJawgJOkiRJklrCAk6SJEmSWsICTpIkSZJawgJOkiRJklrCAk6SJEmSWsICTpIkSZJawoW8JUnSQCVZB/wxsAT486q6rOf8y4FfBg4CB4BfqqovDj1QaZFbTItX6z72wEmSpIFJsgS4HDgLWANsTLKmp9kngYmqegJwDfCG4UYpSe1lAaexkmRdkpuTTCXZPM35lyfZneTGJB9I8uhRxCl1M2/VMmuBqaraU1X3AFuB9d0NqupDVfXNZvc6YMWQY5Sk1rKA09jwqrDayLxVC50C3NK1v685NpMLgPfOa0SStIhYwGmceFVYbWTeatFK8vPABPAHM5zflGRnkp0HDhwYbnCStEBZwGmceFVYbWTeqm32A6d27a9ojtShnuoAAAerSURBVB0mybOBVwHnVNW3p3uhqtpSVRNVNbF8+fJ5CVaS2sZZKKVpdF0VfsYsbTYBmwBWrlw5pMikmc2Vt+ashmQHsDrJaXQKtw3Ai7obJHki8BZgXVXdOvwQJam97IHTOBnYVWHwyrCGxt4MtUpVHQQuBLYDNwFXV9WuJJcmOadp9gfAg4G/SnJDkskRhStJrWMPnMaJV4XVRuatWqeqtgHbeo5d3LX97KEHJUmLhD1wGhteFVYbmbeSJKmbPXAaK14VVhuZt5Ik6ZC+euD6WET2+CTvbs5fn2RVc/w5ST6e5NPNf5852PAlSZIkaXzMWcD1uYjsBcAdVXU68Cbg9c3xrwE/VVWPB84D3jGowCVJkiRp3PTTAzfnIrLN/pXN9jXAs5Kkqj5ZVV9uju8CHpTk+EEELkmSJEnjpp8Crp9FZO9t09xwfydwUk+b5wOfmG1adkmSJEnSzIYyiUmSx9EZVvncGc67uKwkSZIkzaGfHrh+FpG9t02SpcCJwG3N/grgPcAvVNXnp3sDF5eVJEnSqPQxYd/Tk3wiycEk544iRumQfgq4exeRTbKMziKyvWsMTdKZpATgXOCDVVVJHgZcC2yuqo8OKmhJkiRpEPqcsO9LwPnAu4YbnXR/cxZwfS4i+zbgpCRTwMuBQ1cuLgROBy5uFpe9IckjBv5TSJIkSUdnzgn7qmpvVd0IfG8UAUrd+roHro9FZO8GXjDN814HvO4YY5QkSZLmy3QT9p1xNC/kvA4ahr4W8pYkSZI0O+d10DBYwEmSJGmc9TNhn7RgWMBJkiRpnPUzYZ+0YFjASZIkaWz1M2Ffkicn2Udnzoe3JNk1uog17oaykLck
SZK0UPUxYd8OOkMrpZGzB06SJEmSWsICTpIkSZJawgJOkiRJklrCAk6SJEmSWsICTpIkSZJawgJOkiRJklrCAk6SJEmSWsICTpIkSZJawgJOkiRJklrCAk6SJEmSWsICTpIkSZJawgJOkiRJklrCAk6SJEmSWsICTpIkSZJawgJOkiRJklrCAk6SJEmSWsICTpIkSZJawgJOkiRJklrCAk6SJEmSWsICTpIkSZJawgJOkiRJklrCAk6SJEmSWsICTpIkSZJawgJOkiRJklrCAk6SJEmSWmLpqAOQJEmSZrJq87V9t9172dnzGIm0MNgDJ0mSJEktYQEnSZIkSS3hEEpJkiQdEYc1SqPTVw9cknVJbk4ylWTzNOePT/Lu5vz1SVZ1nbuoOX5zkp8YXOjSkTuWXJZGxbxV25izahtzVm0yZwGXZAlwOXAWsAbYmGRNT7MLgDuq6nTgTcDrm+euATYAjwPWAX/SvJ40dMeSy9KomLdqG3NWbWPOqm366YFbC0xV1Z6qugfYCqzvabMeuLLZvgZ4VpI0x7dW1ber6gvAVPN60igcSy5Lo2Leqm3MWbWNOatW6aeAOwW4pWt/X3Ns2jZVdRC4Ezipz+dKw3IsuSyNinmrtjFn1TbmrFplQUxikmQTsKnZvSvJzUfw9JOBr93vNY+iY3tYz5nH9/Kz6Hj00UVyhDHMQ94uNkebs2PInJ3D0f5+mSfmrTnbNgP7LIb598HRmOW9xipnF/q/Ux/8f84sOdtPAbcfOLVrf0VzbLo2+5IsBU4EbuvzuVTVFmBLH7HcT5KdVTVxNM9dbPws5nQsuXw/5u2x83Poy8Dy1pwdDD+LOZmzC4yfxZzM2QXGz2J2/Qyh3AGsTnJakmV0JiWZ7GkzCZzXbJ8LfLCqqjm+oZm55zRgNfBPgwldOmLHksvSqJi3ahtzVm1jzqpV5uyBq6qDSS4EtgNLgCuqaleSS4GdVTUJvA14R5Ip4HY6iU/T7mpgN3AQeGlVfXeefhZpVseSy9KomLdqG3NWbWPOqm3S9osHSTY13dVjz8+iPfy36vBzaA//re7jZ9EO/jvdx8+iHfx3uo+fxexaX8BJkiRJ0rjo5x44SZIkSdIC0OoCLsm6JDcnmUqyedTxjFKSvUk+neSGJDtHHY+mZ87ex5xtB3P2PuZsO5iz9zFn28O8vY95O7fWDqFMsgT4LPAcOgsu7gA2VtXukQY2Ikn2AhNV5Zo3C5Q5ezhzduEzZw9nzi585uzhzNl2MG8PZ97Orc09cGuBqaraU1X3AFuB9SOOSZqNOau2MWfVNuas2si81RFpcwF3CnBL1/6+5ti4KuB9ST6eZNOog9G0zNnDmbMLnzl7OHN24TNnD2fOtoN5ezjzdg5zrgOn1nhaVe1P8gjg/Uk+U1UfGXVQ0izMWbWNOau2MWfVRubtHNrcA7cfOLVrf0VzbCxV1f7mv7cC76HTHa+FxZztYs62gjnbxZxtBXO2iznbGuZtF/N2bm0u4HYAq5OclmQZsAGYHHFMI5HkhCQPObQNPBf459FGpWmYsw1ztjXM2YY52xrmbMOcbRXztmHe9qe1Qyir6mCSC4HtwBLgiqraNeKwRuXfAe9JAp1/03dV1d+PNiT1MmcPY862gDl7GHO2BczZw5izLWHeHsa87UNrlxGQJEmSpHHT5iGUkiRJkjRWLOAkSZIkqSUs4CRJkiSpJSzgJEmSJKklLOAkSZIkqSUs4CRJkiSpJSzgJEmSJKklLOAkSZIkqSX+Pwwzl2FQX4iUAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 1080x360 with 5 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "all_probs = softmax(all_outputs, axis=1)\n",
    "f, arr = plt.subplots(1, 5, figsize=(15, 5))\n",
    "\n",
    "# Panels 1-4: one individual draw each, averaged over the test batch.\n",
    "for i, ax in enumerate(arr[1:]):\n",
    "    ax.set_title(\"sample {}\".format(i))\n",
    "    ax.bar(classes, all_probs[:, :, i].mean(axis=0))\n",
    "\n",
    "# Panel 0: average over the test batch and then over all draws.\n",
    "batch_mean = all_probs.mean(axis=0)\n",
    "arr[0].set_title(\"pred dist\")\n",
    "arr[0].bar(classes, batch_mean.mean(axis=-1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Posterior Samples"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from torch.optim.lr_scheduler import CosineAnnealingLR"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-draw the weights with a small standard deviation before building the optimizers.\n",
    "with torch.no_grad():\n",
    "    resample(model, sigma=0.1)\n",
    "\n",
    "n_train = len(loaders['train'].dataset)\n",
    "\n",
    "opt = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)\n",
    "# The SGLD noise factor is 1 / sqrt(number of training examples).\n",
    "opt_sgld = SGLD(model.parameters(), lr=1e-3, noise_factor=1 / np.sqrt(n_train))\n",
    "criterion = losses.cross_entropy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "0it [00:00, ?it/s]\n",
      "100%|██████████| 5000/5000 [29:43<00:00,  2.80it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 {'loss': 145.5122833251953, 'accuracy': 17.4, 'stats': {}}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "sigma = np.sqrt(1e1)   # standard deviation used for the weight draws and the penalty\n",
    "epochs = 0             # plain-SGD epochs per chain (with `opt`)\n",
    "sgld_epochs = 5000     # SGLD epochs per chain (with `opt_sgld`)\n",
    "samples = 1            # number of chains, each started from a fresh weight draw\n",
    "\n",
    "\n",
    "def regularizer(model):\n",
    "    \"\"\"Return ||w||^2 / (2 * sigma^2), divided by the number of training examples.\"\"\"\n",
    "    sq_norm = sum(p.norm() ** 2 for p in model.parameters())\n",
    "    return sq_norm / (2 * sigma**2 * len(loaders['train'].dataset))\n",
    "\n",
    "\n",
    "chain_outputs = []\n",
    "for i in range(samples):\n",
    "    with torch.no_grad():\n",
    "        resample(model, sigma=sigma)\n",
    "\n",
    "    # Give every chain its own cosine schedule. A scheduler shared across chains\n",
    "    # would start the next chain at the end of the previous annealing cycle, so the\n",
    "    # learning rate is restored to its initial value before the scheduler is rebuilt.\n",
    "    for group in opt_sgld.param_groups:\n",
    "        group['lr'] = group.get('initial_lr', group['lr'])\n",
    "    scheduler_sgld = CosineAnnealingLR(opt_sgld, T_max=sgld_epochs)\n",
    "\n",
    "    res = None  # stays None when both phases run for zero epochs\n",
    "    for _ in tqdm.tqdm(range(epochs)):\n",
    "        res = utils.train_epoch(loaders['train'], model, criterion, opt, regularizer=regularizer)\n",
    "\n",
    "    for _ in tqdm.tqdm(range(sgld_epochs)):\n",
    "        res = utils.train_epoch(loaders['train'], model, criterion, opt_sgld, regularizer=regularizer)\n",
    "        scheduler_sgld.step()\n",
    "    print(i, res)\n",
    "\n",
    "    # Record the outputs of the final iterate of this chain on the test batch.\n",
    "    with torch.no_grad():\n",
    "        chain_outputs.append(model(x_test).data.cpu().numpy()[:, :, None])\n",
    "\n",
    "# Stack along a trailing axis: one slice per chain.\n",
    "all_outputs = np.dstack(chain_outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<BarContainer object of 10 artists>"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUIAAAE/CAYAAAAzEcqDAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAATY0lEQVR4nO3df7Dd9V3n8edrk5K20AIbrrPbJJi4xDrB1m1N07pq7RRtw7Q2uyO4oauCw2ysml1/1Km07lKK3Z3iOKXOlHUaCy4LbQOb6kxGouiIszujFRMoFgNEb1MkSdttgBBKK6WB9/5xvrinZy/kkPO9Offm83zMMJzv5/v5nvM5EJ58v+fcc26qCklq2T+Z9gIkadoMoaTmGUJJzTOEkppnCCU1zxBKap4h1CkpyX9P8sFpr0OLgyGURiRZneTPknw9yQNJfnjaa9L8MoTS/+9TwGeB5cCvATuSzEx3SZpPhlC9S/KrSQ4l+WqSfUku6MY3JPlMkseSfCnJR5OcNnRcJfm5JH/XHfvrSf5Fkr9I8niSW5+dn+RNSQ4meV+Sh5M8mOTfPc+a3p7knu6x/yLJq59j3ncCrwXeX1X/UFWfBu4FfqzPf0ZaWAyhepXklcBW4HVV9TLgrcCD3e6ngV8CzgG+D7gA+LmRu3gr8L3AG4D3ANuAnwBWAd8NXDI0959197UCuBTY1j3+6JpeA9wA/AyDs7yPATuTLJvjKZwP7K+qrw6N/XU3rlOUIVTfngaWAeuSvKiqHqyqzwNU1V1V9ZdVdayqHmQQpB8aOf43qurxqtoL/A3wx1W1v6qOAn8IvGZk/n+uqm9U1f8CbgN+fI41bQE+VlV3VtXTVXUj8A0GsR11BnB0ZOwo8LIxn78WIUOoXlXVLPCLwFXAV5JsT/IKGFx2JvmDJF9O8jjwXxmc0Q37P0O3/2GO7TOGto9U1deGtv8eeMUcy/p24N3dZfFjSR5jcIY519wngJePjL0c+Oocc3WKMITqXVV9sqp+gEGACrim2/XbwAPA2qp6OfA+IBM81NlJTh/aPhf44hzzDgD/parOGvrrpVX1qTnm7gW+I8nwGeD3dOM6RRlC9SrJK5O8uXv97UkGZ3HPdLtfBjwOPJHku4Cf7eEhP5DktCQ/CLwd+J9zzPkd4F1JXp+B05O8bSR2AFTV3wL3AO9P8uIk/wZ4NfDpHtaqBcoQqm/LgA8BDwNfBr4NeG+371eAdzK4zPwd4JYJH+vLwBEGZ4GfAN5VVQ+MTqqqPcC/Bz7azZ8FLnue+90MrO/mfgi4qKoOT7hWLWDxi1m1GCV5E3BzVa2c9lq0+HlGKKl5hlBS87w0ltQ8zwglNc8QSmre0mkvYNQ555xTq1evnvYyJJ1i7rrrroeras5vEVpwIVy9ejV79uyZ9jIknWKS/P1z7fPSWFLzDKGk5hlCSc0zhJKaZwglNc8QSmqeIZTUPEMoqXmGUFLzDKGk5hlCSc1bcJ81Fqy+4rbe7/PBD72t9/uUThWeEUpqniGU1DxDKKl5hlBS8wyhpOYZQknNM4SSmmcIJTXPEEpqniGU1DxDKKl5hlBS8wyhpOaNFcIkG5PsSzKb5Io59r8xyd1JjiW5aGj8Xyb5TJK9ST6X5N/2uXhJ6sNxQ5hkCXAdcCGwDrgkybqRaQ8BlwGfHBn/OvBTVXU+sBH4SJKzJl20JPVpnO8j3ADMVtV+gCTbgU3Afc9OqKoHu33PDB9YVX87dPuLSb4CzACPTbxySerJOJfGK4ADQ9sHu7EXJMkG4DTg8y/0WEmaTyflzZIk/xy4Cfjpqnpmjv1bkuxJsufw4cMnY0mS9I/GCeEhYNXQ9spubCxJXg7cBvxaVf3lXHOqaltVra+q9TMzM+PetST1YpwQ7gbWJlmT5DRgM7BznDvv5v8+8D+qaseJL1OS5s9x3yypqmNJtgK3A0uAG6pqb5KrgT1VtTPJ6xgE
72zgR5N8oHun+MeBNwLLk1zW3eVlVXXPfDwZSQvTQv+FZGP9Fruq2gXsGhm7cuj2bgaXzKPH3QzcPOEaJWle+ckSSc0zhJKaZwglNc8QSmqeIZTUPEMoqXmGUFLzDKGk5hlCSc0zhJKaZwglNc8QSmreWF+6IE2i728e6fNbR16ohf4tKjoxnhFKap4hlNQ8QyipeYZQUvMMoaTmGUJJzTOEkppnCCU1zxBKap4hlNQ8QyipeYZQUvMMoaTmGUJJzTOEkppnCCU1zxBKap4hlNS8sb6qP8lG4LeAJcDHq+pDI/vfCHwEeDWwuap2DO27FPhP3eYHq+rGPhYuaXL+6oGB454RJlkCXAdcCKwDLkmybmTaQ8BlwCdHjv2nwPuB1wMbgPcnOXvyZUtSf8a5NN4AzFbV/qp6CtgObBqeUFUPVtXngGdGjn0r8CdV9WhVHQH+BNjYw7olqTfjhHAFcGBo+2A3No5JjpWkk2JBvFmSZEuSPUn2HD58eNrLkdSYcUJ4CFg1tL2yGxvHWMdW1baqWl9V62dmZsa8a0nqxzgh3A2sTbImyWnAZmDnmPd/O/CWJGd3b5K8pRuTpAXjuCGsqmPAVgYBux+4tar2Jrk6yTsAkrwuyUHgYuBjSfZ2xz4K/DqDmO4Gru7GJGnBGOvnCKtqF7BrZOzKodu7GVz2znXsDcANE6xRkubVgnizRJKmyRBKap4hlNQ8QyipeYZQUvPGetdYWgz6/iaVxfgtKjoxnhFKap5nhA3zDEoa8IxQUvMMoaTmGUJJzTOEkppnCCU1zxBKap4hlNQ8QyipeYZQUvMMoaTmGUJJzTOEkppnCCU1zxBKap4hlNQ8QyipeYZQUvMMoaTmGUJJzTOEkprnL2+SFqC+f7EW+Mu1no9nhJKaZwglNW+sECbZmGRfktkkV8yxf1mSW7r9dyZZ3Y2/KMmNSe5Ncn+S9/a7fEma3HFDmGQJcB1wIbAOuCTJupFplwNHquo84Frgmm78YmBZVb0K+F7gZ56NpCQtFOO8WbIBmK2q/QBJtgObgPuG5mwCrupu7wA+miRAAacnWQq8BHgKeLyfpZ98voAtnZrGuTReARwY2j7Yjc05p6qOAUeB5Qyi+DXgS8BDwG9W1aMTrlmSejXfb5ZsAJ4GXgGsAd6d5DtGJyXZkmRPkj2HDx+e5yVJ0rcaJ4SHgFVD2yu7sTnndJfBZwKPAO8E/qiqvllVXwH+HFg/+gBVta2q1lfV+pmZmRf+LCRpAuOEcDewNsmaJKcBm4GdI3N2Apd2ty8C7qiqYnA5/GaAJKcDbwAe6GPhktSX44awe81vK3A7cD9wa1XtTXJ1knd0064HlieZBX4ZePZHbK4Dzkiyl0FQf7eqPtf3k5CkSYz1Ebuq2gXsGhm7cuj2kwx+VGb0uCfmGpekhcRPlkhqniGU1DxDKKl5hlBS8wyhpOYZQknNOyW+odovQ5A0Cc8IJTXPEEpqniGU1DxDKKl5hlBS8wyhpOYZQknNM4SSmmcIJTXPEEpqniGU1DxDKKl5hlBS8wyhpOYZQknNM4SSmmcIJTXPEEpqniGU1DxDKKl5hlBS8wyhpOYZQknNM4SSmjdWCJNsTLIvyWySK+bYvyzJLd3+O5OsHtr36iSfSbI3yb1JXtzf8iVpcscNYZIlwHXAhcA64JIk60amXQ4cqarzgGuBa7pjlwI3A++qqvOBNwHf7G31ktSDcc4INwCzVbW/qp4CtgObRuZsAm7sbu8ALkgS4C3A56rqrwGq6pGqerqfpUtSP8YJ4QrgwND2wW5szjlVdQw4CiwHvhOoJLcnuTvJeyZfsiT1a+lJuP8fAF4HfB340yR3VdWfDk9KsgXYAnDuuefO85Ik6VuNc0Z4CFg1tL2yG5tzTve64JnAIwzOHv93VT1cVV8HdgGvHX2AqtpWVeurav3MzMwLfxaSNIFxQrgbWJtkTZLTgM3AzpE5O4FLu9sXAXdUVQG3A69K8tIukD8E3NfP
0iWpH8e9NK6qY0m2MojaEuCGqtqb5GpgT1XtBK4HbkoyCzzKIJZU1ZEkH2YQ0wJ2VdVt8/RcJOmEjPUaYVXtYnBZOzx25dDtJ4GLn+PYmxn8CI0kLUh+skRS8wyhpOYZQknNM4SSmmcIJTXPEEpqniGU1DxDKKl5hlBS8wyhpOYZQknNM4SSmmcIJTXPEEpqniGU1DxDKKl5hlBS8wyhpOYZQknNM4SSmmcIJTXPEEpqniGU1DxDKKl5hlBS8wyhpOYZQknNM4SSmmcIJTXPEEpqniGU1DxDKKl5Y4UwycYk+5LMJrlijv3LktzS7b8zyeqR/ecmeSLJr/SzbEnqz3FDmGQJcB1wIbAOuCTJupFplwNHquo84FrgmpH9Hwb+cPLlSlL/xjkj3ADMVtX+qnoK2A5sGpmzCbixu70DuCBJAJL8a+ALwN5+lixJ/RonhCuAA0PbB7uxOedU1THgKLA8yRnArwIfeL4HSLIlyZ4kew4fPjzu2iWpF/P9ZslVwLVV9cTzTaqqbVW1vqrWz8zMzPOSJOlbLR1jziFg1dD2ym5srjkHkywFzgQeAV4PXJTkN4CzgGeSPFlVH5145ZLUk3FCuBtYm2QNg+BtBt45MmcncCnwGeAi4I6qKuAHn52Q5CrgCSMoaaE5bgir6liSrcDtwBLghqram+RqYE9V7QSuB25KMgs8yiCWkrQojHNGSFXtAnaNjF05dPtJ4OLj3MdVJ7A+SZp3frJEUvMMoaTmGUJJzTOEkppnCCU1zxBKap4hlNQ8QyipeYZQUvMMoaTmGUJJzTOEkppnCCU1zxBKap4hlNQ8QyipeYZQUvMMoaTmGUJJzTOEkppnCCU1zxBKap4hlNQ8QyipeYZQUvMMoaTmGUJJzTOEkppnCCU1zxBKap4hlNS8sUKYZGOSfUlmk1wxx/5lSW7p9t+ZZHU3/iNJ7kpyb/f3N/e7fEma3HFDmGQJcB1wIbAOuCTJupFplwNHquo84Frgmm78YeBHq+pVwKXATX0tXJL6Ms4Z4QZgtqr2V9VTwHZg08icTcCN3e0dwAVJUlWfraovduN7gZckWdbHwiWpL+OEcAVwYGj7YDc255yqOgYcBZaPzPkx4O6q+saJLVWS5sfSk/EgSc5ncLn8lufYvwXYAnDuueeejCVJ0j8a54zwELBqaHtlNzbnnCRLgTOBR7rtlcDvAz9VVZ+f6wGqaltVra+q9TMzMy/sGUjShMYJ4W5gbZI1SU4DNgM7R+bsZPBmCMBFwB1VVUnOAm4DrqiqP+9r0ZLUp+OGsHvNbytwO3A/cGtV7U1ydZJ3dNOuB5YnmQV+GXj2R2y2AucBVya5p/vr23p/FpI0gbFeI6yqXcCukbErh24/CVw8x3EfBD444RolaV75yRJJzTOEkppnCCU1zxBKap4hlNQ8QyipeYZQUvMMoaTmGUJJzTOEkppnCCU1zxBKap4hlNQ8QyipeYZQUvMMoaTmGUJJzTOEkppnCCU1zxBKap4hlNQ8QyipeYZQUvMMoaTmGUJJzTOEkppnCCU1zxBKap4hlNQ8QyipeYZQUvPGCmGSjUn2JZlNcsUc+5cluaXbf2eS1UP73tuN70vy1v6WLkn9OG4IkywBrgMuBNYBlyRZNzLtcuBIVZ0HXAtc0x27DtgMnA9sBP5bd3+StGCMc0a4AZitqv1V9RSwHdg0MmcTcGN3ewdwQZJ049ur6htV9QVgtrs/SVowxgnhCuDA0PbBbmzOOVV1DDgKLB/zWEmaqqXTXgBAki3Alm7ziST75umhzgEeHmtN18zTCvp9nAX1fHp4jLGezyL5dwM+n3l1Ao/z7c+1Y5wQHgJWDW2v7MbmmnMwyVLgTOCRMY+lqrYB28ZYy0SS7Kmq9fP9OCeLz2dh8/ksHuNcGu8G1iZZk+Q0Bm9+7ByZsxO4tLt9EXBHVVU3vrl7V3kNsBb4q36WLkn9OO4ZYVUdS7IVuB1YAtxQVXuTXA3sqaqdwPXATUlm
gUcZxJJu3q3AfcAx4Oer6ul5ei6SdEIyOHFrQ5It3WX4KcHns7D5fBaPpkIoSXPxI3aSmtdMCI/3McHFJMmqJH+W5L4ke5P8wrTXNKkkS5J8NskfTHstk0pyVpIdSR5Icn+S75v2miaR5Je6P2d/k+RTSV487TX1rYkQjvkxwcXkGPDuqloHvAH4+UX+fAB+Abh/2ovoyW8Bf1RV3wV8D4v4eSVZAfxHYH1VfTeDN0w3T3dV/WsihIz3McFFo6q+VFV3d7e/yuA/tEX7iZ0kK4G3AR+f9lomleRM4I0MfpKCqnqqqh6b7qomthR4Sfczwi8Fvjjl9fSulRCesh/1677p5zXAndNdyUQ+ArwHeGbaC+nBGuAw8Lvdpf7Hk5w+7UWdqKo6BPwm8BDwJeBoVf3xdFfVv1ZCeEpKcgbwaeAXq+rxaa/nRCR5O/CVqrpr2mvpyVLgtcBvV9VrgK8Bi/Y16SRnM7h6WgO8Ajg9yU9Md1X9ayWEY33UbzFJ8iIGEfxEVf3etNczge8H3pHkQQYvWbw5yc3TXdJEDgIHq+rZM/QdDMK4WP0w8IWqOlxV3wR+D/hXU15T71oJ4TgfE1w0uq84ux64v6o+PO31TKKq3ltVK6tqNYN/L3dU1aI946iqLwMHkryyG7qAwSerFquHgDckeWn35+4CFvGbP89lQXz7zHx7ro8JTnlZk/h+4CeBe5Pc0429r6p2TXFN+n/+A/CJ7n+6+4GfnvJ6TlhV3ZlkB3A3g59W+Cwn4QtSTjY/WSKpea1cGkvSczKEkppnCCU1zxBKap4hlNQ8QyipeYZQUvMMoaTm/V/jIFduImCnegAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 360x360 with 1 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "all_probs = softmax(all_outputs, axis=1)\n",
    "\n",
    "# Index of the most recent sample, taken from the data itself instead of relying\n",
    "# on a loop variable left over from the training cell.\n",
    "sample_idx = all_probs.shape[-1] - 1\n",
    "\n",
    "fig, ax = plt.subplots(1, 1, figsize=(5, 5))\n",
    "\n",
    "# Probabilities of that sample, averaged over the test batch.\n",
    "ax.set_title(\"sample {}\".format(sample_idx))\n",
    "ax.bar(classes, all_probs[:, :, sample_idx].mean(axis=0))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py37",
   "language": "python",
   "name": "py37"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
